{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Differentiable Shallow Water Equations\n",
    "\n",
    "We present a differentiable SWE solver, based on `paddle-harmonics`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from math import ceil, floor\n",
    "\n",
    "import matplotlib.animation as animation\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import paddle\n",
    "import paddle.nn as nn\n",
    "\n",
    "from paddle_harmonics.sht import *\n",
    "from paddle_harmonics.examples import ShallowWaterSolver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the GPU when paddle reports at least one CUDA device, otherwise the CPU\n",
    "device_name = 'gpu' if paddle.device.cuda.device_count() > 0 else 'cpu'\n",
    "device = paddle.device.set_device(device_name)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shallow water solver class `ShallowWaterSolver` is defined in `shallow_water_equations.py`. Below we choose the resolution and the time step and instantiate it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# grid resolution and spectral truncation\n",
    "nlat = 512\n",
    "nlon = 2 * nlat\n",
    "lmax = ceil(128)\n",
    "mmax = lmax\n",
    "\n",
    "# timestepping: step size and total number of steps\n",
    "dt = 75\n",
    "maxiter = 12 * int(86400 / dt)\n",
    "\n",
    "# build the solver and move it to the selected device\n",
    "swe_solver = ShallowWaterSolver(nlat, nlon, dt, lmax=lmax, mmax=mmax)\n",
    "swe_solver = swe_solver.to(device)\n",
    "\n",
    "# grid coordinates exposed by the solver\n",
    "lons, lats = swe_solver.lons, swe_solver.lats\n",
    "\n",
    "# upper-triangular index pairs of an lmax x mmax matrix\n",
    "jj, ii = paddle.triu_indices(lmax, mmax)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initial state in spectral space, as produced by the solver's Galewsky helper\n",
    "uspec0 = swe_solver.galewsky_initial_condition()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready to run the simulation. To perform integration in time, we will use third-order Adams-Bashforth. As we are currently not interested in gradients, we can wrap the time-stepping loop in `paddle.no_grad()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tendency history: the leading axis holds the three Adams-Bashforth time levels,\n",
    "# addressed through the rotating indices inew / inow / iold\n",
    "dudtspec = paddle.zeros([3, 3, swe_solver.lmax, swe_solver.mmax], dtype=paddle.complex128)\n",
    "inew = 0\n",
    "inow = 1\n",
    "iold = 2\n",
    "\n",
    "uspec = uspec0.clone().to(device)\n",
    "\n",
    "# save for later: one snapshot every nskip steps, including step 0\n",
    "nskip = 50\n",
    "utspec = paddle.zeros([maxiter // nskip + 1, *uspec.shape]).to(paddle.complex128)\n",
    "\n",
    "with paddle.no_grad():\n",
    "    # the counter is called `step` so that the builtin iter() is not shadowed\n",
    "    for step in range(maxiter + 1):\n",
    "        t = step * dt\n",
    "\n",
    "        if step % nskip == 0:\n",
    "            utspec[step // nskip] = uspec\n",
    "            print(f\"t={t/3600:.2f} hours\")\n",
    "\n",
    "        dudtspec[inew] = swe_solver.dudtspec(uspec)\n",
    "\n",
    "        # update vort,div,phiv with third-order adams-bashforth.\n",
    "        # start-up: on the first step all three levels hold the same tendency, which\n",
    "        # reduces the formula below to forward euler; on the second step the oldest\n",
    "        # level is filled with the newest tendency.\n",
    "        # NOTE(review): with iold == inew the effective weights on the second step are\n",
    "        # 28/12 and -16/12 rather than the textbook AB2 weights 3/2 and -1/2 - confirm\n",
    "        # that this start-up is intended.\n",
    "        if step == 0:\n",
    "            dudtspec[inow] = dudtspec[inew]\n",
    "            dudtspec[iold] = dudtspec[inew]\n",
    "        elif step == 1:\n",
    "            dudtspec[iold] = dudtspec[inew]\n",
    "\n",
    "        uspec = uspec + swe_solver.dt * (\n",
    "            (23. / 12.) * dudtspec[inew] - (16. / 12.) * dudtspec[inow] + (5. / 12.) * dudtspec[iold]\n",
    "        )\n",
    "\n",
    "        # implicit hyperdiffusion for vort and div.\n",
    "        uspec[1:] = swe_solver.hyperdiff * uspec[1:]\n",
    "\n",
    "        # cycle through the indices: the slot of the oldest level is reused for the\n",
    "        # next tendency, the other two levels each age by one step\n",
    "        inew = (inew - 1) % 3\n",
    "        inow = (inow - 1) % 3\n",
    "        iold = (iold - 1) % 3\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# render the component with index 1 of the final spectral state\n",
    "fig = plt.figure()\n",
    "im = swe_solver.plot_specdata(uspec[1], fig, cmap='twilight_shifted')\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting a video\n",
    "\n",
    "Let us plot the vorticity for our rollout:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the output directory for the animation; unlike the `mkdir -p` shell escape\n",
    "# this works on every platform and is a no-op when the directory already exists\n",
    "os.makedirs('./plots', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare figure for animation\n",
    "fig = plt.figure(figsize=(8, 6), dpi=72)\n",
    "moviewriter = animation.writers['pillow'](fps=20)\n",
    "\n",
    "# set to True to animate the potential vorticity instead of component 1 of the state\n",
    "plot_pvrt = False\n",
    "\n",
    "# saving() calls setup() on entry and finish() on exit, so the GIF is finalized\n",
    "# even if rendering one of the frames raises\n",
    "with moviewriter.saving(fig, './plots/zonal_jet.gif', dpi=72):\n",
    "    for i in range(utspec.shape[0]):\n",
    "        if plot_pvrt:\n",
    "            variable = swe_solver.potential_vorticity(utspec[i])\n",
    "        else:\n",
    "            variable = swe_solver.spec2grid(utspec[i, 1])\n",
    "\n",
    "        # redraw the current snapshot on a cleared figure and append it as a frame\n",
    "        plt.clf()\n",
    "        swe_solver.plot_griddata(variable, fig, cmap='twilight_shifted', antialiased=False)\n",
    "        plt.draw()\n",
    "        moviewriter.grab_frame()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conservation of potential vorticity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# spectral coefficients of the potential vorticity, one slot per saved snapshot\n",
    "pvrttspec = paddle.zeros([maxiter // nskip + 1, lmax, mmax]).to(paddle.complex128)\n",
    "for i in range(utspec.shape[0]):\n",
    "    # potential vorticity of snapshot i, then its transform via grid2spec\n",
    "    pvrt_grid = swe_solver.potential_vorticity(utspec[i])\n",
    "    pvrttspec[i] = swe_solver.grid2spec(pvrt_grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# squared magnitude of every spectral coefficient of the potential vorticity\n",
    "pvrt_power = pvrttspec.abs()**2\n",
    "# alternative diagnostic: pvrt_power = utspec[..., 1, :, :].abs()**2\n",
    "\n",
    "# the first entry of the last axis is counted once and the remaining entries twice\n",
    "# before summing over both spectral axes; the square root gives one value per snapshot\n",
    "total_vrt = paddle.sqrt(\n",
    "    paddle.sum(pvrt_power[..., :1], axis=(-1, -2)) + paddle.sum(2 * pvrt_power[..., 1:], axis=(-1, -2))\n",
    ").numpy()\n",
    "\n",
    "# time of every saved snapshot (snapshots are nskip steps of size dt apart)\n",
    "t = (nskip * dt * paddle.arange(utspec.shape[0])).numpy()\n",
    "\n",
    "plt.plot(t, total_vrt / total_vrt[0], label='Spectral Solver')\n",
    "plt.title('Total vorticity over time')\n",
    "plt.xlabel('time [s]')\n",
    "plt.ylabel('norm relative to first snapshot')\n",
    "plt.ylim((0, 1))\n",
    "plt.legend(loc='lower left')\n",
    "# explicit show() so the cell does not end on the Legend object's repr\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
